{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 8 - Intrinsic Curiosity Module\n",
    "#### Deep Reinforcement Learning *in Action*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/don/git/DeepReinforcementLearningInAction/venv/lib/python3.10/site-packages/gym/envs/registration.py:307: DeprecationWarning: The package name gym_minigrid has been deprecated in favor of minigrid. Please uninstall gym_minigrid and install minigrid with `pip install minigrid`. Future releases will be maintained under the new package name minigrid.\n",
      "  fn()\n",
      "/home/don/git/DeepReinforcementLearningInAction/venv/lib/python3.10/site-packages/gym/envs/registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']\u001b[0m\n",
      "  logger.warn(\n"
     ]
    }
   ],
   "source": [
    "from nes_py.wrappers import JoypadSpace #A\n",
    "import gym_super_mario_bros\n",
    "from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT #B\n",
    "# Build the Super Mario Bros environment. The loops below unpack five values\n",
    "# from env.step(); apply_api_compatibility=True is presumably what enables\n",
    "# that 5-tuple step API (TODO confirm against the gym documentation).\n",
    "env = gym_super_mario_bros.make('SuperMarioBros-v3', apply_api_compatibility=True, render_mode=\"rgb_array\")\n",
    "# Wrap the env with the COMPLEX_MOVEMENT action list; the networks defined\n",
    "# later in this notebook emit 12 action values.\n",
    "env = JoypadSpace(env, COMPLEX_MOVEMENT) #C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/don/git/DeepReinforcementLearningInAction/venv/lib/python3.10/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)\n",
      "  if not isinstance(terminated, (bool, np.bool8)):\n",
      "/home/don/git/DeepReinforcementLearningInAction/venv/lib/python3.10/site-packages/gym/utils/passive_env_checker.py:272: UserWarning: \u001b[33mWARN: No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`.\u001b[0m\n",
      "  logger.warn(\n"
     ]
    }
   ],
   "source": [
    "# Smoke test: drive the environment with uniformly random actions.\n",
    "done = True  # forces a reset on the first iteration\n",
    "for step in range(2500): #D\n",
    "    if done:\n",
    "        state = env.reset()\n",
    "    state, reward, terminated, trunc, info = env.step(env.action_space.sample())\n",
    "    # An episode is over when it terminates OR is truncated; checking only the\n",
    "    # first flag could keep stepping an episode that has already ended.\n",
    "    done = terminated or trunc\n",
    "    env.render()\n",
    "# env is deliberately left open: the cells below keep using it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from skimage.transform import resize #A\n",
    "import numpy as np\n",
    "\n",
    "def downscale_obs(obs, new_size=(42,42), to_gray=True):\n",
    "    \"\"\"\n",
    "    Resize a frame to `new_size` with anti-aliasing.\n",
    "\n",
    "    When `to_gray` is true the resized array is collapsed along axis 2 by\n",
    "    taking the maximum; otherwise the resized array is returned as is.\n",
    "    \"\"\"\n",
    "    resized = resize(obs, new_size, anti_aliasing=True)\n",
    "    return resized.max(axis=2) if to_gray else resized #B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fd812f43cd0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGeCAYAAADSRtWEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgj0lEQVR4nO3df3TU1f3n8deEJMOPZAYSIEM2EwqiINLQNUKc1VqESEAPC5LuWrXHaDl6oIEjZLtq9vij9FtP0H4XkYrRbTmg3zXGg2tw8SxQDDKs24RCJF9QaxRkSzwhQV0zE4KZpJnP/tHj1CkkwyQz3kx4Ps655zife+fOO1fOvM5n5nPnY7MsyxIAAN+xJNMFAAAuTwQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEcmmC/hHwWBQLS0tSk9Pl81mM10OACBKlmWpo6ND2dnZSkrq5zzHipPnnnvOmjx5smW32625c+dahw4duqTnNTc3W5JoNBqNluCtubm53/f7uJwBvfbaayorK9MLL7yggoICbdq0SUVFRWpqatLEiRP7fW56erok6S/vfU+OND4hBIBE4z8X1ORr/2/o/bwvNsuK/Y+RFhQUaM6cOXruueck/e1jNbfbrTVr1uiRRx7p97l+v19Op1NffTxVjnQCCAASjb8jqHFXfSqfzyeHw9HnuJi/w3d3d6uhoUGFhYV/f5GkJBUWFqquru6C8YFAQH6/P6wBAIa/mAfQF198od7eXmVlZYUdz8rKUmtr6wXjKyoq5HQ6Q83tdse6JADAEGT8M67y8nL5fL5Qa25uNl0SAOA7EPOLEMaPH68RI0aora0t7HhbW5tcLtcF4+12u+x2e6zLAAAMcTE/A0pNTVV+fr5qa2tDx4LBoGpra+XxeGL9cgCABBWXy7DLyspUUlKi6667TnPnztWmTZvU2dmp++67Lx4vBwBIQHEJoDvuuEOff/65Hn/8cbW2tuoHP/iB9uzZc8GFCQCAy1dc9gENBvuAACCxGdsHBADApSCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYERc7gcEXKozfz0XcUzXJdwwZEpKWgyqAfBd4gwIAGAEAQQAMIIAAgAYQQABAIwggAAARhBAAAAjCCAAgBEEEADACDaiIq4CVk+//Ut+9Z8jzpF+R0vEMf8y/ZWIY3KS2awKDCWcAQEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABjBRlQY1ZVpizjGdzQ74piRMyLPA2BoifkZ0C9/+UvZbLawNmPGjFi/DAAgwcXlDOiaa67R22+//fcXSeZECwAQLi7JkJycLJfLFY+pAQDDRFwuQvjkk0+UnZ2tqVOn6u6779bp06f7HBsIBOT3+8MaAGD4i3kAFRQUaPv27dqzZ48qKyt16tQp/fCHP1RHR8dFx1dUVMjpdIaa2+2OdUkAgCHIZlmWFc8XaG9v1+TJk7Vx40atWLHigv5AIKBAIBB67Pf75Xa79dXHU+VI5yrxRBfpdgz5mx+MPEdm5H+idXf+c8Qx40eMiTgGwOD5O4Iad9Wn8vl8cjgcfY6L+9UBY8eO1VVXXaUTJ05ctN9ut8tut8e7DADAEBP3U4xz587p5MmTmjRpUrxfCgCQQGJ+BvSLX/xCS5Ys0eTJk9XS0qInnnhCI0aM0J133hnrl0ICsNtS+u3/P2v+a8Q5gpfwKfE4Pl4DEk7MA+izzz7TnXfeqS+//FITJkzQjTfeqPr6ek2YMCHWLwUASGAxD6Dq6upYTwkAGIa4zAwAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACO4UQ+MciaNMl0CAEM4AwIAGEEAAQCMIIAAAEYQQAAAIwggAIARBBAAwAgCCABgBAEEADCCAAIAGEEAAQCMIIAAAEYQ
QAAAIwggAIARBBAAwAgCCABgBAEEADCCAAIAGEEAAQCMIIAAAEYQQAAAIwggAIARBBAAwAgCCABgBAEEADCCAAIAGEEAAQCMiDqADh48qCVLlig7O1s2m007d+4M67csS48//rgmTZqkUaNGqbCwUJ988kms6gUADBNRB1BnZ6dmz56tLVu2XLT/6aef1ubNm/XCCy/o0KFDGjNmjIqKitTV1TXoYgEAw0dytE9YvHixFi9efNE+y7K0adMmPfroo1q6dKkk6eWXX1ZWVpZ27typn/zkJ4OrFgAwbMT0O6BTp06ptbVVhYWFoWNOp1MFBQWqq6u76HMCgYD8fn9YAwAMfzENoNbWVklSVlZW2PGsrKxQ3z+qqKiQ0+kMNbfbHcuSAABDlPGr4MrLy+Xz+UKtubnZdEkAgO9ATAPI5XJJktra2sKOt7W1hfr+kd1ul8PhCGsAgOEvpgE0ZcoUuVwu1dbWho75/X4dOnRIHo8nli8FAEhwUV8Fd+7cOZ04cSL0+NSpU2psbFRGRoZyc3O1du1a/frXv9aVV16pKVOm6LHHHlN2draWLVsWy7oBAAku6gA6cuSIbr755tDjsrIySVJJSYm2b9+uhx56SJ2dnXrggQfU3t6uG2+8UXv27NHIkSNjVzUAIOHZLMuyTBfxbX6/X06nU199PFWOdOPXSAAAouTvCGrcVZ/K5/P1+70+7/AAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACOiDqCDBw9qyZIlys7Ols1m086dO8P67733XtlstrC2aNGiWNULABgmog6gzs5OzZ49W1u2bOlzzKJFi3TmzJlQe/XVVwdVJABg+EmO9gmLFy/W4sWL+x1jt9vlcrkGXBQAYPiLy3dABw4c0MSJEzV9+nStWrVKX375ZZ9jA4GA/H5/WAMADH8xD6BFixbp5ZdfVm1trZ566il5vV4tXrxYvb29Fx1fUVEhp9MZam63O9YlAQCGIJtlWdaAn2yzqaamRsuWLetzzKeffqorrrhCb7/9thYsWHBBfyAQUCAQCD32+/1yu9366uOpcqRzkR4AJBp/R1DjrvpUPp9PDoejz3Fxf4efOnWqxo8frxMnTly03263y+FwhDUAwPAX9wD67LPP9OWXX2rSpEnxfikAQAKJ+iq4c+fOhZ3NnDp1So2NjcrIyFBGRobWr1+v4uJiuVwunTx5Ug899JCmTZumoqKimBYOAEhsUQfQkSNHdPPNN4cel5WVSZJKSkpUWVmpY8eO6aWXXlJ7e7uys7O1cOFC/dM//ZPsdnvsqgYAJLyoA2jevHnq77qFvXv3DqogAMDlgcvMAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGRBVAFRUVmjNnjtLT0zVx4kQtW7ZMTU1NYWO6urpUWlqqzMxMpaWlqbi4WG1tbTEtGgCQ+KIKIK/Xq9LSUtXX12vfvn3q6enRwoUL1dnZGRqzbt067dq1Szt27JDX61VLS4uWL18e88IBAInNZlmWNdAnf/7555o4caK8Xq9uuukm+Xw+TZgwQVVVVfrxj38sSfroo4909dVXq66uTtdff33EOf1+v5xOp776eKoc6XxCCACJxt8R1LirPpXP55PD4ehz3KDe4X0+nyQpIyNDktTQ
0KCenh4VFhaGxsyYMUO5ubmqq6u76ByBQEB+vz+sAQCGvwEHUDAY1Nq1a3XDDTdo1qxZkqTW1lalpqZq7NixYWOzsrLU2tp60XkqKirkdDpDze12D7QkAEACGXAAlZaW6v3331d1dfWgCigvL5fP5wu15ubmQc0HAEgMyQN50urVq/XWW2/p4MGDysnJCR13uVzq7u5We3t72FlQW1ubXC7XReey2+2y2+0DKQMAkMCiOgOyLEurV69WTU2N9u/frylTpoT15+fnKyUlRbW1taFjTU1NOn36tDweT2wqBgAMC1GdAZWWlqqqqkpvvvmm0tPTQ9/rOJ1OjRo1Sk6nUytWrFBZWZkyMjLkcDi0Zs0aeTyeS7oCDgBw+YgqgCorKyVJ8+bNCzu+bds23XvvvZKkZ555RklJSSouLlYgEFBRUZGef/75mBQLABg+BrUPKB7YBwQAie072QcEAMBAEUAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGRBVAFRUVmjNnjtLT0zVx4kQtW7ZMTU1NYWPmzZsnm80W1lauXBnTogEAiS+qAPJ6vSotLVV9fb327dunnp4eLVy4UJ2dnWHj7r//fp05cybUnn766ZgWDQBIfMnRDN6zZ0/Y4+3bt2vixIlqaGjQTTfdFDo+evRouVyu2FQIABiWBvUdkM/nkyRlZGSEHX/llVc0fvx4zZo1S+Xl5Tp//nyfcwQCAfn9/rAGABj+ojoD+rZgMKi1a9fqhhtu0KxZs0LH77rrLk2ePFnZ2dk6duyYHn74YTU1NemNN9646DwVFRVav379QMsAACQom2VZ1kCeuGrVKu3evVvvvvuucnJy+hy3f/9+LViwQCdOnNAVV1xxQX8gEFAgEAg99vv9crvd+urjqXKkc5EeACQaf0dQ4676VD6fTw6Ho89xAzoDWr16td566y0dPHiw3/CRpIKCAknqM4DsdrvsdvtAygAAJLCoAsiyLK1Zs0Y1NTU6cOCApkyZEvE5jY2NkqRJkyYNqEAAwPAUVQCVlpaqqqpKb775ptLT09Xa2ipJcjqdGjVqlE6ePKmqqirdeuutyszM1LFjx7Ru3TrddNNNysvLi8sfAABITFF9B2Sz2S56fNu2bbr33nvV3Nysn/70p3r//ffV2dkpt9ut22+/XY8++mi/nwN+m9/vl9Pp5DsgAEhQcfkOKFJWud1ueb3eaKYEAFymOMUAABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMCIAf8YqWln/nou4ph/7c4c9OtcZ/9/EceMHzFm0K8DAJcbzoAAAEYQQAAAIwggAIARBBAAwAgCCABgBAEEADCCAAIAGEEAAQCMGLIbUXutoHr7uf3Qoo0PRZwjbVFrv/09vSMizvHFyYyIYz798YsRxwAAwnEGBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMCIIbsPKJIJ/9oVcczz6/6l3/7zwch/fsnb6yKO6bWCEceMsJH1APBtvCsCAIwggAAARhBAAAAjCCAAgBEEEADACAIIAGAEAQQAMIIAAgAYkbAbUTty7BHH/O7LG/vt/6I7LeIc2b8/HnHM9OzSiGNk6+fuepcoaI88x6ZF/W++laR/P+b8oGsBgMGK6gyosrJSeXl5cjgccjgc8ng82r17d6i/q6tLpaWlyszMVFpamoqLi9XW1hbzogEAiS+qAMrJydGGDRvU0NCgI0eOaP78+Vq6dKk++OADSdK6deu0a9cu7dixQ16vVy0tLVq+fHlcCgcAJLaoPoJbsmRJ2OMnn3xSlZWVqq+vV05OjrZu3aqqqirNnz9f
krRt2zZdffXVqq+v1/XXXx+7qgEACW/AFyH09vaqurpanZ2d8ng8amhoUE9PjwoLC0NjZsyYodzcXNXV1fU5TyAQkN/vD2sAgOEv6gA6fvy40tLSZLfbtXLlStXU1GjmzJlqbW1Vamqqxo4dGzY+KytLra2tfc5XUVEhp9MZam63O+o/AgCQeKIOoOnTp6uxsVGHDh3SqlWrVFJSog8//HDABZSXl8vn84Vac3PzgOcCACSOqC/DTk1N1bRp0yRJ+fn5Onz4sJ599lndcccd6u7uVnt7e9hZUFtbm1wuV5/z2e122e2RL6kGAAwvg96IGgwGFQgElJ+fr5SUFNXW1ob6mpqadPr0aXk8nsG+DABgmInqDKi8vFyLFy9Wbm6uOjo6VFVVpQMHDmjv3r1yOp1asWKFysrKlJGRIYfDoTVr1sjj8QzoCrgRtqR+7yL6219tjjjHxpaifvvf/58zIs4x4d91RxzjOhT5jqixYOuNPOahr0oijvl+yW8ijpmSEnmTLgAMRlQBdPbsWd1zzz06c+aMnE6n8vLytHfvXt1yyy2SpGeeeUZJSUkqLi5WIBBQUVGRnn/++bgUDgBIbFEF0NatW/vtHzlypLZs2aItW7YMqigAwPDHj5ECAIwggAAARhBAAAAjCCAAgBEEEADACAIIAGBEwt4RdXZq5DFNr/a/0XTCJ5E3mXY7Iy+RZYtcS0yMiDzEVffXiGNW/q+fRxzzP17/bxHHpCWNjFwQYECvFXlzeFCDv0txIkqxXcIbyXeEMyAAgBEEEADACAIIAGAEAQQAMIIAAgAYQQABAIwggAAARhBAAAAjhuxGVF/wa1nBvvPxhuf+U8Q5st/r7Lf/nHtU1HUNdd2OyJvMkr+O/L/95sfWRRxTsz7ynVVzkrmzKmLrUjaZXn3wvohj7A1jIr/Yd7XJPEZsl3Bz5qxbm/vt33f1rhhVExlnQAAAIwggAIARBBAAwAgCCABgBAEEADCCAAIAGEEAAQCMIIAAAEYM2Y2o93xyu5LH2Pvsd9V3RZxjOG40jYXzWZFvJ5vW0hNxzH/4oCTimB3XvNRvPxtV8W0BK/K/u998+f2IY3ra+37v+MbuNU9HHJOSYBtRO/vZvP+N5dt+0W//C5P+TcQ57nf2v5n1UjYLS5wBAQAMIYAAAEYQQAAAIwggAIARBBAAwAgCCABgBAEEADBiyO4D6t7iUjBlZJ/9wYzIF+jbglYsS7qs9KRFvrHdyGczIo4p/Nmqfvvf9bxwyTVh+Ptj14SIY3Z8+m8jjqm8pf/9Z5eqJ8HeQlIv4Y50v777v/fb/4uD/zHiHHMWPN9v/7nuOOwDqqysVF5enhwOhxwOhzwej3bv3h3qnzdvnmw2W1hbuXJlNC8BALhMRHUGlJOTow0bNujKK6+UZVl66aWXtHTpUh09elTXXHONJOn+++/Xr371q9BzRo8eHduKAQDDQlQBtGTJkrDHTz75pCorK1VfXx8KoNGjR8vlcsWuQgDAsDTgixB6e3tVXV2tzs5OeTye0PFXXnlF48eP16xZs1ReXq7z58/3O08gEJDf7w9rAIDhL+qLEI4fPy6Px6Ouri6lpaWppqZGM2fOlCTdddddmjx5srKzs3Xs2DE9/PDDampq0htvvNHnfBUVFVq/fv3A/wIAQEKKOoCmT5+uxsZG+Xw+vf766yopKZHX69XMmTP1wAMPhMZ9//vf16RJk7RgwQKdPHlSV1xxxUXnKy8vV1lZWeix3++X2+0ewJ8CAEgkUQdQamqqpk2bJknKz8/X4cOH9eyzz+rFF1+8YGxBQYEk6cSJE30GkN1ul90e+afTAQDDy6A3ogaDQQUCgYv2NTY2SpImTZo02JcBAAwzNsuyLnmrVXl5uRYvXqzc3Fx1dHSoqqpKTz31lPbu3aupU6eqqqpKt956qzIzM3Xs2DGtW7dOOTk58nq9l1yQ3++X0+nUPC1Vsi1lQH8UvhtJI/veKPyNr378g377u8bxYxz4
O+sSPpPp+EHkm1GmOb+OQTWXp46zkW8S6fig//fm3kCX/lz5X+Tz+eRwOPocF9VHcGfPntU999yjM2fOyOl0Ki8vT3v37tUtt9yi5uZmvf3229q0aZM6OzvldrtVXFysRx99NJqXAABcJqIKoK1bt/bZ53a7ozrTAQBc3vj8AwBgBAEEADCCAAIAGEEAAQCMIIAAAEYQQAAAI4bsHVGv+d822dMi3/UUJl38FzDCHYp7FRg+kmwJdgvS4WjyJYyZ03934FyP/lwZeRrOgAAARhBAAAAjCCAAgBEEEADACAIIAGAEAQQAMIIAAgAYQQABAIwYshtRU2y9SrGRjwCQaEYoeEnjeIcHABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABhBAAEAjCCAAABGEEAAACMIIACAEQQQAMAIAggAYAQBBAAwggACABgxqADasGGDbDab1q5dGzrW1dWl0tJSZWZmKi0tTcXFxWpraxtsnQCAYWbAAXT48GG9+OKLysvLCzu+bt067dq1Szt27JDX61VLS4uWL18+6EIBAMPLgALo3Llzuvvuu/W73/1O48aNCx33+XzaunWrNm7cqPnz5ys/P1/btm3TH//4R9XX18esaABA4htQAJWWluq2225TYWFh2PGGhgb19PSEHZ8xY4Zyc3NVV1d30bkCgYD8fn9YAwAMf8nRPqG6ulrvvfeeDh8+fEFfa2urUlNTNXbs2LDjWVlZam1tveh8FRUVWr9+fbRlAAASXFRnQM3NzXrwwQf1yiuvaOTIkTEpoLy8XD6fL9Sam5tjMi8AYGiLKoAaGhp09uxZXXvttUpOTlZycrK8Xq82b96s5ORkZWVlqbu7W+3t7WHPa2trk8vluuicdrtdDocjrAEAhr+oPoJbsGCBjh8/Hnbsvvvu04wZM/Twww/L7XYrJSVFtbW1Ki4uliQ1NTXp9OnT8ng8sasaAJDwogqg9PR0zZo1K+zYmDFjlJmZGTq+YsUKlZWVKSMjQw6HQ2vWrJHH49H1118fu6oBAAkv6osQInnmmWeUlJSk4uJiBQIBFRUV6fnnn4/1ywAAEpzNsizLdBHf5vf75XQ6tfbdJbKnpZguBwAQpcC5Hm26cZd8Pl+/3+vzW3AAACMIIACAEQQQAMAIAggAYAQBBAAwIuaXYQ/WNxflBTp7DFcCABiIb96/I11kPeQuw/7ss8/kdrtNlwEAGKTm5mbl5OT02T/kAigYDKqlpUXp6emy2WyS/rY3yO12q7m5md+KiwPWN75Y3/hifeNrIOtrWZY6OjqUnZ2tpKS+v+kZch/BJSUl9ZmY/FhpfLG+8cX6xhfrG1/Rrq/T6Yw4hosQAABGEEAAACMSIoDsdrueeOIJ2e1206UMS6xvfLG+8cX6xlc813fIXYQAALg8JMQZEABg+CGAAABGEEAAACMIIACAEUM+gLZs2aLvfe97GjlypAoKCvSnP/3JdEkJ6eDBg1qyZImys7Nls9m0c+fOsH7LsvT4449r0qRJGjVqlAoLC/XJJ5+YKTYBVVRUaM6cOUpPT9fEiRO1bNkyNTU1hY3p6upSaWmpMjMzlZaWpuLiYrW1tRmqOLFUVlYqLy8vtBnS4/Fo9+7doX7WNrY2bNggm82mtWvXho7FY42HdAC99tprKisr0xNPPKH33ntPs2fPVlFRkc6ePWu6tITT2dmp2bNna8uWLRftf/rpp7V582a98MILOnTokMaMGaOioiJ1dXV9x5UmJq/Xq9LSUtXX12vfvn3q6enRwoUL1dnZGRqzbt067dq1Szt27JDX61VLS4uWL19usOrEkZOTow0bNqihoUFHjhzR/PnztXTpUn3wwQeSWNtYOnz4sF588UXl5eWFHY/LGltD2Ny5c63S0tLQ497eXis7O9uqqKgwWFXik2TV1NSEHgeDQcvlclm/+c1vQsfa29stu91uvfrqqwYqTHxnz561JFler9eyrL+tZ0pKirVj
x47QmD//+c+WJKuurs5UmQlt3Lhx1u9//3vWNoY6OjqsK6+80tq3b5/1ox/9yHrwwQcty4rfv98hewbU3d2thoYGFRYWho4lJSWpsLBQdXV1Bisbfk6dOqXW1tawtXY6nSooKGCtB8jn80mSMjIyJEkNDQ3q6ekJW+MZM2YoNzeXNY5Sb2+vqqur1dnZKY/Hw9rGUGlpqW677bawtZTi9+93yP0Y6Te++OIL9fb2KisrK+x4VlaWPvroI0NVDU+tra2SdNG1/qYPly4YDGrt2rW64YYbNGvWLEl/W+PU1FSNHTs2bCxrfOmOHz8uj8ejrq4upaWlqaamRjNnzlRjYyNrGwPV1dV67733dPjw4Qv64vXvd8gGEJCoSktL9f777+vdd981XcqwMn36dDU2Nsrn8+n1119XSUmJvF6v6bKGhebmZj344IPat2+fRo4c+Z297pD9CG78+PEaMWLEBVdZtLW1yeVyGapqePpmPVnrwVu9erXeeustvfPOO2G3FXG5XOru7lZ7e3vYeNb40qWmpmratGnKz89XRUWFZs+erWeffZa1jYGGhgadPXtW1157rZKTk5WcnCyv16vNmzcrOTlZWVlZcVnjIRtAqampys/PV21tbehYMBhUbW2tPB6PwcqGnylTpsjlcoWttd/v16FDh1jrS2RZllavXq2amhrt379fU6ZMCevPz89XSkpK2Bo3NTXp9OnTrPEABYNBBQIB1jYGFixYoOPHj6uxsTHUrrvuOt19992h/47LGg/yoom4qq6utux2u7V9+3brww8/tB544AFr7NixVmtrq+nSEk5HR4d19OhR6+jRo5Yka+PGjdbRo0etv/zlL5ZlWdaGDRussWPHWm+++aZ17Ngxa+nSpdaUKVOsr7/+2nDliWHVqlWW0+m0Dhw4YJ05cybUzp8/HxqzcuVKKzc319q/f7915MgRy+PxWB6Px2DVieORRx6xvF6vderUKevYsWPWI488YtlsNusPf/iDZVmsbTx8+yo4y4rPGg/pALIsy/rtb39r5ebmWqmpqdbcuXOt+vp60yUlpHfeeceSdEErKSmxLOtvl2I/9thjVlZWlmW3260FCxZYTU1NZotOIBdbW0nWtm3bQmO+/vpr6+c//7k1btw4a/To0dbtt99unTlzxlzRCeRnP/uZNXnyZCs1NdWaMGGCtWDBglD4WBZrGw//GEDxWGNuxwAAMGLIfgcEABjeCCAAgBEEEADACAIIAGAEAQQAMIIAAgAYQQABAIwggAAARhBAAAAjCCAAgBEEEADACAIIAGDE/we+y0bYt47onwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show the rendered frame next to its downscaled grayscale version. Two\n",
    "# imshow calls on the same Axes would only leave the second image visible.\n",
    "frame = env.render()\n",
    "fig, (ax_full, ax_small) = plt.subplots(1, 2, figsize=(10, 4))\n",
    "ax_full.imshow(frame)\n",
    "ax_full.set_title('rendered frame')\n",
    "ax_small.imshow(downscale_obs(frame))\n",
    "ax_small.set_title('downscaled to 42x42, grayscale')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from collections import deque\n",
    "\n",
    "def prepare_state(state): #A\n",
    "    \"\"\"Downscale a raw frame to grayscale and return it as a float tensor with a new leading dimension.\"\"\"\n",
    "    gray_frame = downscale_obs(state, to_gray=True)\n",
    "    frame_tensor = torch.from_numpy(gray_frame).float()\n",
    "    return frame_tensor.unsqueeze(dim=0)\n",
    "\n",
    "\n",
    "def prepare_multi_state(state1, state2): #B\n",
    "    \"\"\"\n",
    "    Slide the frame stack forward by one frame.\n",
    "\n",
    "    state1: stacked-frame tensor indexed as [batch, frame, ...]; only batch\n",
    "        entry 0 is updated.\n",
    "    state2: raw frame; it is downscaled to grayscale before being stored.\n",
    "\n",
    "    Returns a new tensor (state1 itself is not modified) in which every frame\n",
    "    of entry 0 has moved one slot towards index 0 and the last slot holds the\n",
    "    new frame. Works for any number of stacked frames, not only 3.\n",
    "    \"\"\"\n",
    "    new_frame = torch.from_numpy(downscale_obs(state2, to_gray=True)).float()\n",
    "    shifted = state1.clone()\n",
    "    # Read from the untouched state1 so the shift cannot overwrite its own source.\n",
    "    shifted[0, :-1] = state1[0, 1:]\n",
    "    shifted[0, -1] = new_frame\n",
    "    return shifted\n",
    "\n",
    "\n",
    "def prepare_initial_state(state,N=3): #C\n",
    "    \"\"\"Build the first stacked state: the downscaled grayscale frame repeated N times, with a leading dimension of size 1 added.\"\"\"\n",
    "    gray_frame = downscale_obs(state, to_gray=True)\n",
    "    stacked = torch.from_numpy(gray_frame).float().repeat((N,1,1))\n",
    "    return stacked[None]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy(qvalues, eps=None): #A\n",
    "    \"\"\"\n",
    "    Pick an action from a row of Q-values.\n",
    "\n",
    "    qvalues: tensor of action values; the sampling branch expects a 2-D\n",
    "        tensor with the actions along dim 1.\n",
    "    eps: if given, act epsilon-greedily (with probability eps a uniformly\n",
    "        random action, otherwise the argmax). If None, sample an action from\n",
    "        a softmax over the normalized Q-values.\n",
    "\n",
    "    Returns a tensor holding the chosen action index.\n",
    "    \"\"\"\n",
    "    if eps is not None:\n",
    "        if torch.rand(1) < eps:\n",
    "            # Draw from every available action. A hard-coded high=7 could\n",
    "            # never explore actions 7..11 of a 12-action space and would go\n",
    "            # out of range for smaller action spaces.\n",
    "            return torch.randint(low=0, high=qvalues.shape[-1], size=(1,))\n",
    "        return torch.argmax(qvalues)\n",
    "    # dim=1 is what the implicit-dim softmax chose for 2-D input; passing it\n",
    "    # explicitly avoids the deprecation warning.\n",
    "    return torch.multinomial(F.softmax(F.normalize(qvalues), dim=1), num_samples=1) #B"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import shuffle\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class ExperienceReplay:\n",
    "    \"\"\"\n",
    "    Fixed-capacity replay memory of (state1, action, reward, state2) tuples.\n",
    "\n",
    "    Once the memory holds N entries, each new experience overwrites a\n",
    "    uniformly random slot. The stored list is shuffled on every 500th insertion.\n",
    "    \"\"\"\n",
    "    def __init__(self, N=500, batch_size=100):\n",
    "        self.N = N #A\n",
    "        self.batch_size = batch_size #B\n",
    "        self.memory = []\n",
    "        self.counter = 0\n",
    "\n",
    "    def add_memory(self, state1, action, reward, state2):\n",
    "        \"\"\"Store one experience, overwriting a random slot once the memory is full.\"\"\"\n",
    "        self.counter += 1\n",
    "        if self.counter % 500 == 0: #C\n",
    "            self.shuffle_memory()\n",
    "\n",
    "        experience = (state1, action, reward, state2)\n",
    "        if len(self.memory) < self.N: #D\n",
    "            self.memory.append(experience)\n",
    "        else:\n",
    "            # randint's upper bound is exclusive: using N-1 meant the last slot\n",
    "            # could never be overwritten (and N=1 raised ValueError).\n",
    "            rand_index = np.random.randint(0, len(self.memory))\n",
    "            self.memory[rand_index] = experience\n",
    "\n",
    "    def shuffle_memory(self): #E\n",
    "        \"\"\"Shuffle the stored experiences in place.\"\"\"\n",
    "        shuffle(self.memory)\n",
    "\n",
    "    def get_batch(self): #F\n",
    "        \"\"\"\n",
    "        Sample up to batch_size distinct experiences and stack them into tensors.\n",
    "\n",
    "        Returns (state1_batch, action_batch, reward_batch, state2_batch), or\n",
    "        None (after printing an error) when the memory is empty.\n",
    "        \"\"\"\n",
    "        if len(self.memory) < 1:\n",
    "            print(\"Error: No data in memory.\")\n",
    "            return None\n",
    "        batch_size = min(self.batch_size, len(self.memory))\n",
    "        #G\n",
    "        ind = np.random.choice(len(self.memory), batch_size, replace=False)\n",
    "        batch = [self.memory[i] for i in ind] #batch is a list of tuples\n",
    "        # Each stored state carries a leading dimension of size 1; drop it before stacking.\n",
    "        state1_batch = torch.stack([x[0].squeeze(dim=0) for x in batch], dim=0)\n",
    "        action_batch = torch.Tensor([x[1] for x in batch]).long()\n",
    "        reward_batch = torch.Tensor([x[2] for x in batch])\n",
    "        state2_batch = torch.stack([x[3].squeeze(dim=0) for x in batch], dim=0)\n",
    "        return state1_batch, action_batch, reward_batch, state2_batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Phi(nn.Module): #A\n",
    "    \"\"\"Encoder: four stride-2 3x3 convolutions with ELU activations, flattened per sample.\"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        conv_cfg = dict(kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv1 = nn.Conv2d(3, 32, **conv_cfg)\n",
    "        self.conv2 = nn.Conv2d(32, 32, **conv_cfg)\n",
    "        self.conv3 = nn.Conv2d(32, 32, **conv_cfg)\n",
    "        self.conv4 = nn.Conv2d(32, 32, **conv_cfg)\n",
    "\n",
    "    def forward(self,x):\n",
    "        features = F.normalize(x)\n",
    "        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):\n",
    "            features = F.elu(conv(features))\n",
    "        # For 3x42x42 input conv4 yields [N, 32, 3, 3], i.e. 288 values per sample.\n",
    "        return features.flatten(start_dim=1)\n",
    "\n",
    "class Gnet(nn.Module): #B\n",
    "    \"\"\"Inverse model: maps a pair of encoded states to a softmax distribution over 12 actions.\"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.linear1 = nn.Linear(576,256)\n",
    "        self.linear2 = nn.Linear(256,12)\n",
    "\n",
    "    def forward(self, state1,state2):\n",
    "        joined = torch.cat([state1, state2], dim=1)\n",
    "        hidden = F.relu(self.linear1(joined))\n",
    "        # NOTE(review): this notebook feeds these softmax outputs to\n",
    "        # nn.CrossEntropyLoss, which expects unnormalized logits -- confirm\n",
    "        # whether the extra softmax is intended.\n",
    "        return F.softmax(self.linear2(hidden), dim=1)\n",
    "\n",
    "class Fnet(nn.Module): #C\n",
    "    \"\"\"\n",
    "    Forward model: predicts the next encoded state from the current encoded\n",
    "    state concatenated with a one-hot action (288 + 12 = 300 inputs).\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.linear1 = nn.Linear(300,256)\n",
    "        self.linear2 = nn.Linear(256,288)\n",
    "\n",
    "    def forward(self,state,action):\n",
    "        \"\"\"\n",
    "        state: encoded states, one row per sample.\n",
    "        action: integer action indices in [0, 12), one per sample (any shape\n",
    "            that flattens to the batch size, e.g. (batch, 1)).\n",
    "\n",
    "        Returns the predicted next encoding, one row of 288 values per sample.\n",
    "        \"\"\"\n",
    "        # One-hot encode the actions. F.one_hot also handles a batch of one,\n",
    "        # where action.squeeze() collapsed to a 0-d tensor and broke the\n",
    "        # previous stack-and-index construction. #D\n",
    "        action_ = F.one_hot(action.reshape(-1).long(), num_classes=12).to(state)\n",
    "        x = torch.cat((state, action_), dim=1)\n",
    "        y = F.relu(self.linear1(x))\n",
    "        y = self.linear2(y)\n",
    "        return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Qnetwork(nn.Module):\n",
    "    \"\"\"Q-network: four stride-2 3x3 convolutions followed by two linear layers that output 12 action values.\"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        conv_cfg = dict(kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, **conv_cfg)\n",
    "        self.conv2 = nn.Conv2d(32, 32, **conv_cfg)\n",
    "        self.conv3 = nn.Conv2d(32, 32, **conv_cfg)\n",
    "        self.conv4 = nn.Conv2d(32, 32, **conv_cfg)\n",
    "        self.linear1 = nn.Linear(288,100)\n",
    "        self.linear2 = nn.Linear(100,12)\n",
    "\n",
    "    def forward(self,x):\n",
    "        features = F.normalize(x)\n",
    "        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):\n",
    "            features = F.elu(conv(features))\n",
    "        # Row-major flatten of each sample's conv output; linear1 expects 288 values.\n",
    "        flat = features.flatten(start_dim=1)\n",
    "        hidden = F.elu(self.linear1(flat))\n",
    "        return self.linear2(hidden) #size N, 12"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters read by the training code in this notebook.\n",
    "params = {\n",
    "    'batch_size':150, # replay batch size; training starts once this many experiences are stored\n",
    "    'beta':0.2, # loss_fn: weight on its forward_loss argument (1 - beta goes to inverse_loss)\n",
    "    'lambda':0.1, # loss_fn: weight on the Q-loss term\n",
    "    'eta': 1.0, # minibatch_train: intrinsic reward is the forward prediction error scaled by 1/eta\n",
    "    'gamma':0.2, # minibatch_train: factor applied to the max Q-value that is added to the reward\n",
    "    'max_episode_len':100, # training loop: iterations before the progress check applies\n",
    "    'min_progress':15, # training loop: x_pos gain required by the progress check\n",
    "    'action_repeats':6, # training loop: env steps taken per chosen action\n",
    "    'frames_per_state':3 # maxlen of the frame deque that is stacked into one state\n",
    "}\n",
    "\n",
    "# Replay memory and the four networks.\n",
    "replay = ExperienceReplay(N=1000, batch_size=params['batch_size'])\n",
    "Qmodel = Qnetwork()\n",
    "encoder = Phi()\n",
    "forward_model = Fnet()\n",
    "inverse_model = Gnet()\n",
    "# reduction='none' keeps the element-wise losses so ICM can reduce them itself.\n",
    "forward_loss = nn.MSELoss(reduction='none')\n",
    "inverse_loss = nn.CrossEntropyLoss(reduction='none')\n",
    "qloss = nn.MSELoss()\n",
    "# A single optimizer updates the parameters of all four networks.\n",
    "all_model_params = list(Qmodel.parameters()) + list(encoder.parameters()) #A\n",
    "all_model_params += list(forward_model.parameters()) + list(inverse_model.parameters())\n",
    "opt = optim.Adam(lr=0.001, params=all_model_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(q_loss, inverse_loss, forward_loss):\n",
    "    \"\"\"\n",
    "    Combine the three losses into one scalar.\n",
    "\n",
    "    The inverse and forward losses are mixed with weights (1 - beta) and beta,\n",
    "    averaged over all elements, and added to lambda * q_loss.\n",
    "    \"\"\"\n",
    "    beta, lam = params['beta'], params['lambda']\n",
    "    icm_loss = (1 - beta) * inverse_loss + beta * forward_loss\n",
    "    icm_mean = icm_loss.sum() / icm_loss.numel()\n",
    "    return icm_mean + lam * q_loss\n",
    "\n",
    "def reset_env():\n",
    "    \"\"\"\n",
    "    Restart the episode and build a fresh stacked initial state from the rendered frame.\n",
    "    \"\"\"\n",
    "    env.reset()\n",
    "    first_frame = env.render()\n",
    "    return prepare_initial_state(first_frame)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.11"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ICM(state1, action, state2, forward_scale=1., inverse_scale=1e4):\n",
    "    \"\"\"\n",
    "    Run the curiosity module on a batch of transitions.\n",
    "\n",
    "    Returns (forward_pred_err, inverse_pred_err), each with one row per sample:\n",
    "    the scaled forward-model error summed over dim 1, and the scaled\n",
    "    inverse-model loss.\n",
    "    \"\"\"\n",
    "    state1_hat = encoder(state1) #A\n",
    "    state2_hat = encoder(state2)\n",
    "    # The forward model sees detached inputs and targets, so its error does not\n",
    "    # backpropagate into the encoder.\n",
    "    state2_hat_pred = forward_model(state1_hat.detach(), action.detach()) #B\n",
    "    elementwise_err = forward_loss(state2_hat_pred, state2_hat.detach())\n",
    "    forward_pred_err = forward_scale * elementwise_err.sum(dim=1).unsqueeze(dim=1)\n",
    "    pred_action = inverse_model(state1_hat, state2_hat) #C\n",
    "    per_sample_inv = inverse_loss(pred_action, action.detach().flatten())\n",
    "    inverse_pred_err = inverse_scale * per_sample_inv.unsqueeze(dim=1)\n",
    "    return forward_pred_err, inverse_pred_err"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def minibatch_train(use_extrinsic=True):\n",
    "    \"\"\"\n",
    "    Sample a batch from the replay memory and compute the ICM and Q-learning losses.\n",
    "\n",
    "    use_extrinsic: if true, the stored environment reward is added to the\n",
    "        intrinsic (forward-prediction-error) reward.\n",
    "\n",
    "    Returns (forward_pred_err, inverse_pred_err, q_loss).\n",
    "    \"\"\"\n",
    "    state1_batch, action_batch, reward_batch, state2_batch = replay.get_batch()\n",
    "    action_batch = action_batch.view(action_batch.shape[0],1) #A\n",
    "    reward_batch = reward_batch.view(reward_batch.shape[0],1)\n",
    "\n",
    "    forward_pred_err, inverse_pred_err = ICM(state1_batch, action_batch, state2_batch) #B\n",
    "    i_reward = (1. / params['eta']) * forward_pred_err #C\n",
    "    reward = i_reward.detach() #D\n",
    "    # Honour the function argument; the body used to read the unrelated global\n",
    "    # `use_explicit`, so the parameter had no effect.\n",
    "    if use_extrinsic: #E\n",
    "        reward = reward + reward_batch\n",
    "    qvals = Qmodel(state2_batch) #F\n",
    "    # Bootstrap every transition with the best Q-value of its OWN next state;\n",
    "    # torch.max(qvals) without a dim took a single maximum over the whole batch.\n",
    "    next_q_max = qvals.detach().max(dim=1, keepdim=True)[0]\n",
    "    reward = reward + params['gamma'] * next_q_max\n",
    "    reward_pred = Qmodel(state1_batch)\n",
    "    # The target is a constant copy of the prediction in which only the Q-value\n",
    "    # of the action actually taken is replaced by the computed reward.\n",
    "    reward_target = reward_pred.detach().clone()\n",
    "    batch_idx = torch.arange(action_batch.shape[0])\n",
    "    reward_target[batch_idx, action_batch.squeeze(dim=1)] = reward.squeeze(dim=1)\n",
    "    q_loss = 1e5 * qloss(F.normalize(reward_pred), F.normalize(reward_target))\n",
    "    return forward_pred_err, inverse_pred_err, q_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.13"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_6414/227893001.py:8: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return torch.multinomial(F.softmax(F.normalize(qvalues)), num_samples=1) #B\n"
     ]
    }
   ],
   "source": [
    "epochs = 5000\n",
    "env.reset()\n",
    "state1 = prepare_initial_state(env.render())\n",
    "eps=0.15\n",
    "losses = []\n",
    "episode_length = 0\n",
    "switch_to_eps_greedy = 1000\n",
    "state_deque = deque(maxlen=params['frames_per_state'])\n",
    "e_reward = 0.\n",
    "last_x_pos = 0 #A\n",
    "#ep_lengths = []\n",
    "use_explicit = False # whether the environment reward is added to the curiosity reward\n",
    "for i in range(epochs):\n",
    "    opt.zero_grad()\n",
    "    episode_length += 1\n",
    "    # Action selection needs no gradients.\n",
    "    with torch.no_grad():\n",
    "        q_val_pred = Qmodel(state1) #B\n",
    "    if i > switch_to_eps_greedy: #C\n",
    "        action = int(policy(q_val_pred,eps))\n",
    "    else:\n",
    "        action = int(policy(q_val_pred))\n",
    "    for j in range(params['action_repeats']): #D\n",
    "        state2, e_reward_, done, trunc, info = env.step(action)\n",
    "        # Stop repeating once the episode is over. The reset happens once,\n",
    "        # further down, so the transition stored below still starts from the\n",
    "        # state the action was chosen in.\n",
    "        if done or trunc:\n",
    "            break\n",
    "        e_reward += e_reward_\n",
    "        state_deque.append(prepare_state(state2))\n",
    "    state2 = torch.stack(list(state_deque),dim=1) #E\n",
    "    replay.add_memory(state1, action, e_reward, state2) #F\n",
    "    e_reward = 0\n",
    "    if episode_length > params['max_episode_len']: #G\n",
    "        # last_x_pos is the position at the previous checkpoint (it is no\n",
    "        # longer overwritten on every env step, which made this difference\n",
    "        # always zero and ended every episode here).\n",
    "        if (info['x_pos'] - last_x_pos) < params['min_progress']:\n",
    "            done = True\n",
    "        else:\n",
    "            # Enough progress: start a new measuring window.\n",
    "            last_x_pos = info['x_pos']\n",
    "            episode_length = 0\n",
    "    if done or trunc:\n",
    "        #ep_lengths.append(info['x_pos'])\n",
    "        state1 = reset_env()\n",
    "        last_x_pos = 0\n",
    "        episode_length = 0\n",
    "    else:\n",
    "        state1 = state2\n",
    "    if len(replay.memory) < params['batch_size']:\n",
    "        continue\n",
    "    forward_pred_err, inverse_pred_err, q_loss = minibatch_train(use_extrinsic=use_explicit) #H\n",
    "    # Keyword arguments: loss_fn takes (q_loss, inverse_loss, forward_loss) and\n",
    "    # the two error terms were previously passed in swapped positions.\n",
    "    loss = loss_fn(q_loss, inverse_loss=inverse_pred_err, forward_loss=forward_pred_err) #I\n",
    "    # Log plain floats so the history does not keep every autograd graph alive.\n",
    "    loss_list = (q_loss.mean().item(), forward_pred_err.flatten().mean().item(),\n",
    "                 inverse_pred_err.flatten().mean().item())\n",
    "    losses.append(loss_list)\n",
    "    loss.backward()\n",
    "    opt.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let the current Q-network play for 500 steps with an epsilon-greedy policy.\n",
    "done = True  # forces a reset on the first iteration\n",
    "for step in range(500):\n",
    "    if done:\n",
    "        env.reset()\n",
    "        state1 = prepare_initial_state(env.render())\n",
    "    # Inference only: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        q_val_pred = Qmodel(state1)\n",
    "    action = int(policy(q_val_pred,eps))\n",
    "    state2, reward, terminated, trunc, info = env.step(action)\n",
    "    # Treat truncation like termination so a finished episode is not stepped again.\n",
    "    done = terminated or trunc\n",
    "    state1 = prepare_multi_state(state1,state2)\n",
    "    env.render()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
